{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "其实就是去基线的Actor_Critic算法\n",
    "\n",
    "Actor_Critic算法中使用critic模型估计state的价值,也就是估计Q\n",
    "\n",
    "这样估计出来的Q是没有去基线的,而要去基线也非常简单,target-value即可\n",
    "\n",
    "换个角度来想这个问题,target是根据next_state估计出来的,value是根据state估计出来的\n",
    "\n",
    "所以两者的差值可以视为action好坏的衡量,这可以作为actor模型训练的依据"
   ]
  },
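  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The reasoning above can be written out as a short derivation. This is a sketch in notation introduced here rather than taken from the notebook: $V$ is the critic, $V_{\\text{delay}}$ its delayed copy, $\\gamma$ the discount (0.99 in the code below), $d_t$ the `over` flag, and $\\delta_t$ the difference `target - value`.\n",
    "\n",
    "$$\\text{target}_t = r_t + \\gamma\\,(1 - d_t)\\,V_{\\text{delay}}(s_{t+1}) \\approx Q(s_t, a_t)$$\n",
    "\n",
    "$$\\delta_t = \\text{target}_t - V(s_t) \\approx Q(s_t, a_t) - V(s_t)$$\n",
    "\n",
    "Under these assumptions $\\delta_t$ is a one-step estimate of the advantage of $a_t$: positive when the action turned out better than the critic expected from $s_t$, negative when it turned out worse."
   ]
  },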
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAASAAAADMCAYAAADTcn7NAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAUdElEQVR4nO3db2xT570H8K8dxwaSHIeExV5GLNBtNZrLn20BgtdJq0ZG1uZWZc2LdUJtViFQmYNKM3G16LYgUKUg9oK1Gw0vqkJfXMaUSrS3ES2KTAna6hJImy6kELW67U0K2ObPzXESGtvx+d0XvTmtIbA4f/zE8fcjHYnzPI/t3zmOvxw/x8e2iIiAiEgBq+oCiCh7MYCISBkGEBEpwwAiImUYQESkDAOIiJRhABGRMgwgIlKGAUREyjCAiEgZZQF08OBBLFmyBPPmzUNlZSU6OjpUlUJEiigJoL/+9a9oaGjA7t278eGHH2LVqlWorq5GOBxWUQ4RKWJRcTFqZWUl1qxZgz//+c8AAMMwUFZWhu3bt+P3v/99usshIkVs6X7AWCyGzs5ONDY2mm1WqxVVVVUIBALj3iYajSIajZrrhmHg5s2bKC4uhsVimfGaiSg1IoLBwUGUlpbCar37G620B9D169eRSCTgcrmS2l0uFy5dujTubZqamrBnz550lEdE06i/vx+LFy++a3/aA2gyGhsb0dDQYK7rug6Px4P+/n5omqawMiIaTyQSQVlZGQoKCu45Lu0BtGjRIuTk5CAUCiW1h0IhuN3ucW/jcDjgcDjuaNc0jQFENIv9symStJ8Fs9vtqKiogN/vN9sMw4Df74fX6013OUSkkJK3YA0NDairq8Pq1auxdu1a/PGPf8Tw8DCefvppFeUQkSJKAuhXv/oVrl27hl27diEYDOIHP/gB3n333TsmpoloblPyOaCpikQicDqd0HWdc0BEs9BEX6O8FoyIlGEAEZEyDCAiUoYBRETKMICISBkGEBEpwwAiImUYQESkDAOIiJRhABGRMgwgIlKGAUREyjCAiEgZBhARKcMAIiJlGEBEpAwDiIiUYQARkTIMICJShgFERMowgIhIGQYQESnDACIiZRhARKQMA4iIlGEAEZEyDCAiUoYBRETKpBxAZ86cwaOPPorS0lJYLBa8+eabSf0igl27duG73/0u5s+fj6qqKnz66adJY27evIlNmzZB0zQUFhZi8+bNGBoamtKGEFHmSTmAhoeHsWrVKhw8eHDc/v379+Pll1/GoUOHcPbsWeTl5aG6uhojIyPmmE2bNqGnpwdtbW1obW3FmTNnsHXr1slvBRFlJpkCAHL8+HFz3TAMcbvd8oc//MFsGxgYEIfDIX/5y19EROSTTz4RAHLu3DlzzDvvvCMWi0UuX748ocfVdV0AiK7rUymfiGbIRF+j0zoH9PnnnyMYDKKqqspsczqdqKysRCAQAAAEAgEUFhZi9erV5piqqipYrVacPXt23PuNRqOIRCJJCxFlvmkNoGAwCABwuVxJ7S6Xy+wLBoMoKSlJ6rfZbCgqKjLH3K6pqQlOp9NcysrKprNsIlIkI86CNTY2Qtd1c+nv71ddEhFNg2kNILfbDQAIhUJJ7aFQyOxzu90Ih8NJ/aOjo7h586Y55nYOhwOapiUtRJT5pjWAli5dCrfbDb/fb7ZFIhGcPXsWXq8XAOD1ejEwMIDOzk5zzKlTp2AYBiorK6ezHCKa5Wyp3mBoaAifffaZuf7555+jq6sLRUVF8Hg82LFjB1588UXcf//9WLp0KV544QWUlpZi48aNAIAHHngAv/jFL7BlyxYcOnQI8Xgc9fX1eOKJJ1BaWjptG0ZEGSDV02vvvfeeALhjqaurE5GvT8W/8MIL4nK5xOFwyPr166W3tzfpPm7cuCG//vWvJT8/XzRNk6effloGBwen/RQfEakx0deoRUREYf5NSiQSgdPphK7rnA8imoUm+hrNiLNgRDQ3
MYCISBkGEBEpwwAiImUYQESkDAOIiJRhABGRMgwgIlKGAUREyjCAiEgZBhARKcMAIiJlGEBEpAwDiIiUYQARkTIMICJShgFERMowgIhIGQYQESnDACIiZVL+WR6i6SQiGAp+htjQzf9vsUBb/ABy5xcorYvSgwFEigmCH5/EwP/84+tViwUPPPbvDKAswbdgpJQYBozRuOoySBEGECklhgEjwQDKVgwgUkqMUSTiI6rLIEUYQKRUIj6C2OANc92akwtrjl1hRZRODCBSSwTf/nVw27x82DgBnTUYQDSrWKw5sFhzVJdBaZJSADU1NWHNmjUoKChASUkJNm7ciN7e3qQxIyMj8Pl8KC4uRn5+PmpraxEKhZLG9PX1oaamBgsWLEBJSQl27tyJ0dHRqW8NZTwGUHZJKYDa29vh8/nwwQcfoK2tDfF4HBs2bMDw8LA55rnnnsPbb7+NlpYWtLe348qVK3j88cfN/kQigZqaGsRiMbz//vt4/fXXceTIEezatWv6tooyhhgJAN+8BbNYc2DJ4cfTsoZMQTgcFgDS3t4uIiIDAwOSm5srLS0t5piLFy8KAAkEAiIicuLECbFarRIMBs0xzc3NommaRKPRCT2urusCQHRdn0r5NAsMX++T86/6pOPQFuk4tEUutOyVxGhcdVk0RRN9jU5pDkjXdQBAUVERAKCzsxPxeBxVVVXmmGXLlsHj8SAQCAAAAoEAVqxYAZfLZY6prq5GJBJBT0/PuI8TjUYRiUSSFpobEvFY0iQ0LBZYLBZ1BVFaTTqADMPAjh078OCDD2L58uUAgGAwCLvdjsLCwqSxLpcLwWDQHPPt8BnrH+sbT1NTE5xOp7mUlZVNtmyaZYz4CPDtAKKsMukA8vl8uHDhAo4dOzad9YyrsbERuq6bS39//4w/JqXHyEAQYvAERLaa1GxffX09WltbcebMGSxevNhsd7vdiMViGBgYSDoKCoVCcLvd5piOjo6k+xs7SzY25nYOhwMOh2MypdIs9/UkNGWrlI6ARAT19fU4fvw4Tp06haVLlyb1V1RUIDc3F36/32zr7e1FX18fvF4vAMDr9aK7uxvhcNgc09bWBk3TUF5ePpVtoTnAUVAMgHNA2SKlIyCfz4ejR4/irbfeQkFBgTln43Q6MX/+fDidTmzevBkNDQ0oKiqCpmnYvn07vF4v1q1bBwDYsGEDysvL8eSTT2L//v0IBoN4/vnn4fP5eJSTZWScuZ/cBYUAJ6GzRkoB1NzcDAB46KGHktoPHz6M3/zmNwCAAwcOwGq1ora2FtFoFNXV1XjllVfMsTk5OWhtbcW2bdvg9XqRl5eHuro67N27d2pbQhnJSCTP/1htvA4sm1hkvP+GZrlIJAKn0wld16FpmupyaJJEBJfPvYmrH71jtpVWPIrSin/jqfgMN9HXKK8FI6US8WjSOo+AsgsDiBQSGLcHEC/DyCoMIFJHBF/975U7mvn2K3swgEgZEYExGlNdBinEACIiZRhANLtY+CeZTfhskzpiJH0Y0ZJjw7xC1z1uQHMNA4iUMRKjSdeCWSxW5OTOU1gRpRsDiJQRY/S2K+EtsOTkKquH0o8BRMoYozHIt38V1WKB1cYAyiYMIFImNjyA+MiguW6xWGDlEVBWYQCROiJ3fBuixco/yWzCZ5uIlGEAEZEyDCBShl/HSgwgUiZx23VgVpsdFn4SOqvw2SZljPhI0rq9oBhWfhAxqzCASJlE7KukdavVxq/iyDIMIFLm1o0vk9YtObm8GDXL8NkmZeSOL6S3cQ4oy/DZplnDYrXxJ3myDAOIZg2LxcI5oCzDACIlRAQihuoySDEGEKkhBr8PmlL7ZVSiiRIRDA8PY3R0dPx+I4HYyK2ktlgsjoGBgXHHWywWFBQUwMqLVecUBhDNCBHBU089hY6OjnH77TYr/uOJNVixpNhsO3K0Bf+5tWnc8U6nE36/H263e0bqJTUYQDRjrl27hsuXL4/b58xzID9/ET4eeghx
w4Gl8/+B//7y73cdf+vWLSQSvHZsrknpeLa5uRkrV66EpmnQNA1erxfvvPPN73qPjIzA5/OhuLgY+fn5qK2tRSgUSrqPvr4+1NTUYMGCBSgpKcHOnTvvephOc5cBBy6NbMDV6H24Hvega3A9rn11998Qp7kppQBavHgx9u3bh87OTpw/fx4/+9nP8Nhjj6GnpwcA8Nxzz+Htt99GS0sL2tvbceXKFTz++OPm7ROJBGpqahCLxfD+++/j9ddfx5EjR7Br167p3Sqa9Qzk4JahAfj6tHvMmAd9hL8Ln3VkihYuXCivvvqqDAwMSG5urrS0tJh9Fy9eFAASCAREROTEiRNitVolGAyaY5qbm0XTNIlGoxN+TF3XBYDouj7V8mmGJBIJ+clPfiIAxl2c+XlyYPch2fviWdnzYoe8tO+/ZM2/Lrvr+IULF8qXX36perNogib6Gp30HFAikUBLSwuGh4fh9XrR2dmJeDyOqqoqc8yyZcvg8XgQCASwbt06BAIBrFixAi7XN7/9VF1djW3btqGnpwc//OEPU6rh0qVLyM/Pn+wm0AwSEdy6deuu/dFYFH9vb4aOf8Go2OGyf4HLwfHnf4Cv/94+/fRT6Lo+E+XSNBsaGprQuJQDqLu7G16vFyMjI8jPz8fx48dRXl6Orq4u2O12FBYWJo13uVwIBoMAgGAwmBQ+Y/1jfXcTjUYRjUbN9UgkAgDQdZ3zR7OUiNxz0ngkNoo3Tn8M4OMJ318kEoHdzrdpmWB4eHhC41IOoO9///vo6uqCrut44403UFdXh/b29pQLTEVTUxP27NlzR3tlZSU0jROXs5FhGCgoKJi2+7PZbKioqMD3vve9abtPmjljBwn/TMqf6rLb7bjvvvtQUVGBpqYmrFq1Ci+99BLcbjdisdgdHyQLhULmZzfcbvcdZ8XG1u/1+Y7Gxkboum4u/f39qZZNRLPQlD9WahgGotEoKioqkJubC7/fb/b19vair68PXq8XAOD1etHd3Y1wOGyOaWtrg6ZpKC8vv+tjOBwO89T/2EJEmS+lt2CNjY14+OGH4fF4MDg4iKNHj+L06dM4efIknE4nNm/ejIaGBhQVFUHTNGzfvh1erxfr1q0DAGzYsAHl5eV48sknsX//fgSDQTz//PPw+XxwOBwzsoFENHulFEDhcBhPPfUUrl69CqfTiZUrV+LkyZP4+c9/DgA4cOAArFYramtrEY1GUV1djVdeecW8fU5ODlpbW7Ft2zZ4vV7k5eWhrq4Oe/fund6tolkhLy9v2o5WeR3Y3GQRue2nKTNAJBKB0+mErut8OzZLiQiuXbuWdPZyKqxWK9xuN3Jycqbl/mhmTfQ1ymvBaEZYLBaUlJSoLoNmOR7TEpEyDCAiUoYBRETKMICISBkGEBEpwwAiImUYQESkDAOIiJRhABGRMgwgIlKGAUREyjCAiEgZBhARKcMAIiJlGEBEpAwDiIiUYQARkTIMICJShgFERMowgIhIGQYQESnDACIiZRhARKQMA4iIlGEAEZEyDCAiUoYBRETKMICISBkGEBEpwwAiImVsqguYDBEBAEQiEcWVENF4xl6bY6/Vu8nIALpx4wYAoKysTHElRHQvg4ODcDqdd+3PyAAqKioCAPT19d1z4yhZJBJBWVkZ+vv7oWma6nIyAvfZ5IgIBgcHUVpaes9xGRlAVuvXU1dOp5N/FJOgaRr3W4q4z1I3kYMDTkITkTIMICJSJiMDyOFwYPfu3XA4HKpLySjcb6njPptZFvln58mIiGZIRh4BEdHcwAAiImUYQESkDAOIiJTJyAA6ePAglixZgnnz5qGyshIdHR2qS1KmqakJa9asQUFBAUpKSrBx40b09vYmjRkZGYHP50NxcTHy8/NRW1uLUCiUNKavrw81NTVYsGABSkpKsHPnToyOjqZzU5TZt28fLBYLduzYYbZxn6WJZJhjx46J3W6X1157TXp6emTLli1SWFgooVBIdWlKVFdXy+HDh+XChQvS
1dUljzzyiHg8HhkaGjLHPPPMM1JWViZ+v1/Onz8v69atkx//+Mdm/+joqCxfvlyqqqrko48+khMnTsiiRYuksbFRxSalVUdHhyxZskRWrlwpzz77rNnOfZYeGRdAa9euFZ/PZ64nEgkpLS2VpqYmhVXNHuFwWABIe3u7iIgMDAxIbm6utLS0mGMuXrwoACQQCIiIyIkTJ8RqtUowGDTHNDc3i6ZpEo1G07sBaTQ4OCj333+/tLW1yU9/+lMzgLjP0iej3oLFYjF0dnaiqqrKbLNaraiqqkIgEFBY2eyh6zqAby7Y7ezsRDweT9pny5Ytg8fjMfdZIBDAihUr4HK5zDHV1dWIRCLo6elJY/Xp5fP5UFNTk7RvAO6zdMqoi1GvX7+ORCKR9KQDgMvlwqVLlxRVNXsYhoEdO3bgwQcfxPLlywEAwWAQdrsdhYWFSWNdLheCwaA5Zrx9OtY3Fx07dgwffvghzp07d0cf91n6ZFQA0b35fD5cuHABf/vb31SXMqv19/fj2WefRVtbG+bNm6e6nKyWUW/BFi1ahJycnDvORoRCIbjdbkVVzQ719fVobW3Fe++9h8WLF5vtbrcbsVgMAwMDSeO/vc/cbve4+3Ssb67p7OxEOBzGj370I9hsNthsNrS3t+Pll1+GzWaDy+XiPkuTjAogu92OiooK+P1+s80wDPj9fni9XoWVqSMiqK+vx/Hjx3Hq1CksXbo0qb+iogK5ublJ+6y3txd9fX3mPvN6veju7kY4HDbHtLW1QdM0lJeXp2dD0mj9+vXo7u5GV1eXuaxevRqbNm0y/819liaqZ8FTdezYMXE4HHLkyBH55JNPZOvWrVJYWJh0NiKbbNu2TZxOp5w+fVquXr1qLrdu3TLHPPPMM+LxeOTUqVNy/vx58Xq94vV6zf6xU8obNmyQrq4ueffdd+U73/lOVp1S/vZZMBHus3TJuAASEfnTn/4kHo9H7Ha7rF27Vj744APVJSkDYNzl8OHD5pivvvpKfvvb38rChQtlwYIF8stf/lKuXr2adD9ffPGFPPzwwzJ//nxZtGiR/O53v5N4PJ7mrVHn9gDiPksPfh0HESmTUXNARDS3MICISBkGEBEpwwAiImUYQESkDAOIiJRhABGRMgwgIlKGAUREyjCAiEgZBhARKcMAIiJl/g8qz+Y0M9E5XAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 300x300 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import gym\n",
    "\n",
    "\n",
    "#定义环境\n",
    "class MyWrapper(gym.Wrapper):\n",
    "\n",
    "    def __init__(self):\n",
    "        env = gym.make('CartPole-v1', render_mode='rgb_array')\n",
    "        super().__init__(env)\n",
    "        self.env = env\n",
    "        self.step_n = 0\n",
    "\n",
    "    def reset(self):\n",
    "        state, _ = self.env.reset()\n",
    "        self.step_n = 0\n",
    "        return state\n",
    "\n",
    "    def step(self, action):\n",
    "        state, reward, terminated, truncated, info = self.env.step(action)\n",
    "        over = terminated or truncated\n",
    "\n",
    "        #限制最大步数\n",
    "        self.step_n += 1\n",
    "        if self.step_n >= 200:\n",
    "            over = True\n",
    "        \n",
    "        #没坚持到最后,扣分\n",
    "        if over and self.step_n < 200:\n",
    "            reward = -1000\n",
    "\n",
    "        return state, reward, over\n",
    "\n",
    "    #打印游戏图像\n",
    "    def show(self):\n",
    "        from matplotlib import pyplot as plt\n",
    "        plt.figure(figsize=(3, 3))\n",
    "        plt.imshow(self.env.render())\n",
    "        plt.show()\n",
    "\n",
    "\n",
    "env = MyWrapper()\n",
    "\n",
    "env.reset()\n",
    "\n",
    "env.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Sequential(\n",
       "   (0): Linear(in_features=4, out_features=64, bias=True)\n",
       "   (1): ReLU()\n",
       "   (2): Linear(in_features=64, out_features=64, bias=True)\n",
       "   (3): ReLU()\n",
       "   (4): Linear(in_features=64, out_features=2, bias=True)\n",
       "   (5): Softmax(dim=1)\n",
       " ),\n",
       " Sequential(\n",
       "   (0): Linear(in_features=4, out_features=64, bias=True)\n",
       "   (1): ReLU()\n",
       "   (2): Linear(in_features=64, out_features=64, bias=True)\n",
       "   (3): ReLU()\n",
       "   (4): Linear(in_features=64, out_features=1, bias=True)\n",
       " ))"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "#演员模型,计算每个动作的概率\n",
    "model_actor = torch.nn.Sequential(\n",
    "    torch.nn.Linear(4, 64),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(64, 64),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(64, 2),\n",
    "    torch.nn.Softmax(dim=1),\n",
    ")\n",
    "\n",
    "#评委模型,计算每个状态的价值\n",
    "model_critic = torch.nn.Sequential(\n",
    "    torch.nn.Linear(4, 64),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(64, 64),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(64, 1),\n",
    ")\n",
    "\n",
    "model_critic_delay = torch.nn.Sequential(\n",
    "    torch.nn.Linear(4, 64),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(64, 64),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(64, 1),\n",
    ")\n",
    "\n",
    "model_critic_delay.load_state_dict(model_critic.state_dict())\n",
    "\n",
    "model_actor, model_critic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-977.0"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from IPython import display\n",
    "import random\n",
    "\n",
    "\n",
    "#玩一局游戏并记录数据\n",
    "def play(show=False):\n",
    "    state = []\n",
    "    action = []\n",
    "    reward = []\n",
    "    next_state = []\n",
    "    over = []\n",
    "\n",
    "    s = env.reset()\n",
    "    o = False\n",
    "    while not o:\n",
    "        #根据概率采样\n",
    "        prob = model_actor(torch.FloatTensor(s).reshape(1, 4))[0].tolist()\n",
    "        a = random.choices(range(2), weights=prob, k=1)[0]\n",
    "\n",
    "        ns, r, o = env.step(a)\n",
    "\n",
    "        state.append(s)\n",
    "        action.append(a)\n",
    "        reward.append(r)\n",
    "        next_state.append(ns)\n",
    "        over.append(o)\n",
    "\n",
    "        s = ns\n",
    "\n",
    "        if show:\n",
    "            display.clear_output(wait=True)\n",
    "            env.show()\n",
    "\n",
    "    state = torch.FloatTensor(state).reshape(-1, 4)\n",
    "    action = torch.LongTensor(action).reshape(-1, 1)\n",
    "    reward = torch.FloatTensor(reward).reshape(-1, 1)\n",
    "    next_state = torch.FloatTensor(next_state).reshape(-1, 4)\n",
    "    over = torch.LongTensor(over).reshape(-1, 1)\n",
    "\n",
    "    return state, action, reward, next_state, over, reward.sum().item()\n",
    "\n",
    "\n",
    "state, action, reward, next_state, over, reward_sum = play()\n",
    "\n",
    "reward_sum"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer_actor = torch.optim.Adam(model_actor.parameters(), lr=1e-3)\n",
    "optimizer_critic = torch.optim.Adam(model_critic.parameters(), lr=1e-2)\n",
    "\n",
    "\n",
    "def requires_grad(model, value):\n",
    "    for param in model.parameters():\n",
    "        param.requires_grad_(value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([24, 1])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def train_critic(state, reward, next_state, over):\n",
    "    requires_grad(model_actor, False)\n",
    "    requires_grad(model_critic, True)\n",
    "\n",
    "    #计算values和targets\n",
    "    value = model_critic(state)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        target = model_critic_delay(next_state)\n",
    "    target = target * 0.99 * (1 - over) + reward\n",
    "\n",
    "    #时序差分误差,也就是tdloss\n",
    "    loss = torch.nn.functional.mse_loss(value, target)\n",
    "\n",
    "    loss.backward()\n",
    "    optimizer_critic.step()\n",
    "    optimizer_critic.zero_grad()\n",
    "\n",
    "    #减去value相当于去基线\n",
    "    return (target - value).detach()\n",
    "\n",
    "\n",
    "value = train_critic(state, reward, next_state, over)\n",
    "\n",
    "value.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-28.23444175720215"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def train_actor(state, action, value):\n",
    "    requires_grad(model_actor, True)\n",
    "    requires_grad(model_critic, False)\n",
    "\n",
    "    #重新计算动作的概率\n",
    "    prob = model_actor(state)\n",
    "    prob = prob.gather(dim=1, index=action)\n",
    "\n",
    "    #根据策略梯度算法的导函数实现\n",
    "    #函数中的Q(state,action),这里使用critic模型估算\n",
    "    prob = (prob + 1e-8).log() * value\n",
    "    loss = -prob.mean()\n",
    "\n",
    "    loss.backward()\n",
    "    optimizer_actor.step()\n",
    "    optimizer_actor.zero_grad()\n",
    "\n",
    "    return loss.item()\n",
    "\n",
    "\n",
    "train_actor(state, action, value)"
   ]
  },
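  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loss in `train_actor` can be read as the following objective. This is a sketch in notation introduced here, not taken from the notebook: $\\pi_\\theta$ is the actor, $\\delta_t$ is the detached `target - value` returned by `train_critic`, $T$ is the number of steps in the episode, and $\\epsilon = 10^{-8}$ is the constant added before the logarithm.\n",
    "\n",
    "$$L_{\\text{actor}}(\\theta) = -\\frac{1}{T} \\sum_{t=1}^{T} \\delta_t \\, \\log\\big(\\pi_\\theta(a_t \\mid s_t) + \\epsilon\\big)$$\n",
    "\n",
    "Because $\\delta_t$ carries no gradient, differentiating gives\n",
    "\n",
    "$$\\nabla_\\theta L_{\\text{actor}}(\\theta) = -\\frac{1}{T} \\sum_{t=1}^{T} \\delta_t \\, \\nabla_\\theta \\log\\big(\\pi_\\theta(a_t \\mid s_t) + \\epsilon\\big)$$\n",
    "\n",
    "so a descent step tends to raise the probability of actions with $\\delta_t > 0$ and lower it for actions with $\\delta_t < 0$."
   ]
  },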
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 -45.02870178222656 -980.8\n",
      "100 -1.033004879951477 -314.45\n",
      "200 3.2805840969085693 -694.1\n",
      "300 0.5094745755195618 -210.6\n",
      "400 -2.177739143371582 -730.3\n",
      "500 0.9841430187225342 200.0\n",
      "600 -0.7153235077857971 200.0\n",
      "700 0.1674824208021164 200.0\n",
      "800 -0.20255199074745178 200.0\n",
      "900 0.11738529801368713 200.0\n"
     ]
    }
   ],
   "source": [
    "def train():\n",
    "    model_actor.train()\n",
    "    model_critic.train()\n",
    "\n",
    "    #训练N局\n",
    "    for epoch in range(1000):\n",
    "\n",
    "        #一个epoch最少玩N步\n",
    "        steps = 0\n",
    "        while steps < 200:\n",
    "            state, action, reward, next_state, over, _ = play()\n",
    "            steps += len(state)\n",
    "\n",
    "            #训练两个模型\n",
    "            value = train_critic(state, reward, next_state, over)\n",
    "            loss = train_actor(state, action, value)\n",
    "\n",
    "        #复制参数\n",
    "        for param, param_delay in zip(model_critic.parameters(),\n",
    "                                      model_critic_delay.parameters()):\n",
    "            value = param_delay.data * 0.7 + param.data * 0.3\n",
    "            param_delay.data.copy_(value)\n",
    "\n",
    "        if epoch % 100 == 0:\n",
    "            test_result = sum([play()[-1] for _ in range(20)]) / 20\n",
    "            print(epoch, loss, test_result)\n",
    "\n",
    "\n",
    "train()"
   ]
  },
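  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The parameter copy at the end of each epoch in `train` is a soft update. As a sketch, writing $\\theta$ for the parameters of `model_critic` and $\\theta_{\\text{delay}}$ for those of `model_critic_delay` (symbols introduced here for illustration):\n",
    "\n",
    "$$\\theta_{\\text{delay}} \\leftarrow 0.7\\,\\theta_{\\text{delay}} + 0.3\\,\\theta$$\n",
    "\n",
    "If $\\theta$ were held fixed, which is an idealization since the critic keeps training, the gap would shrink geometrically after $k$ updates:\n",
    "\n",
    "$$\\theta_{\\text{delay}}^{(k)} - \\theta = 0.7^{k}\\,\\big(\\theta_{\\text{delay}}^{(0)} - \\theta\\big)$$\n",
    "\n",
    "For example, $0.7^{10} \\approx 0.028$, so after ten epochs the delayed critic would have closed about 97% of the initial gap."
   ]
  },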
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAASAAAADMCAYAAADTcn7NAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAT8ElEQVR4nO3db2xT570H8K+dxCYhHIckjd0oseBqaDTiz7YA4bRXl6n1yGiExpoX3YTarEJUZQ4qzYTUSC299FZKx3TF1o2GF1OhbzqmTKJTI0pvlJSg3hpS0kU3BIho1SoRYLvAcuykxHbi333R5YBLEuIk9hPT70c6Uv08j+3feWp/Oec5sW0REQERkQJW1QUQ0XcXA4iIlGEAEZEyDCAiUoYBRETKMICISBkGEBEpwwAiImUYQESkDAOIiJRRFkCHDh3CsmXLsGjRIlRVVaGrq0tVKUSkiJIA+utf/4qGhga88sor+PTTT7F27VpUV1cjGAyqKIeIFLGo+DBqVVUV1q9fjz/96U8AgHg8jvLycuzevRsvvvhiusshIkWy0/2E0WgU3d3daGxsNNusVis8Hg98Pt+k94lEIohEIubteDyOmzdvoqioCBaLJeU1E1FyRAThcBilpaWwWqc+0Up7AF2/fh3j4+NwOp0J7U6nE5cuXZr0Pk1NTdi/f386yiOieTQ4OIiysrIp+9MeQLPR2NiIhoYG87ZhGHC73RgcHISmaQorI6LJhEIhlJeXY8mSJdOOS3sAFRcXIysrC4FAIKE9EAjA5XJNeh+73Q673X5Xu6ZpDCCiBexeSyRpvwpms9lQWVmJ9vZ2sy0ej6O9vR26rqe7HCJSSMkpWENDA+rq6rBu3Tps2LABv//97zEyMoJnnnlGRTlEpIiSAHryySfx1VdfYd++ffD7/fjBD36AkydP3rUwTUT3NyV/BzRXoVAIDocDhmFwDYhoAZrpe5SfBSMiZRhARKQMA4iIlGEAEZEyDCAiUoYBRETKMICISBkGEBEpwwAiImUYQESkDAOIiJRhABGRMgwgIlKGAUREyjCAiEgZBhARKcMAIiJlGEBEpAwDiIiUYQARkTIMICJShgFERMowgIhIGQYQESnDACIiZRhARKQMA4iIlGEAEZEySQfQ6dOnsXXrVpSWlsJiseDdd99N6BcR7Nu3Dw8++CByc3Ph8Xhw+fLlhDE3b97E9u3boWkaCgoKsGPHDgwPD89pR4go8yQdQCMjI1i7di0OHTo0af+BAwfwxhtv4PDhwzh79iwWL16M6upqjI6OmmO2b9+Ovr4+tLW1obW1FadPn8azzz47+70goswkcwBAjh8/bt6Ox+Picrnkd7/7ndk2NDQkdrtd/vKXv4iIyIULFwSAfPLJJ+aY999/XywWi1y5cmVGz2sYhgAQwzDmUj4RpchM36Pzugb0xRdfwO/3w+PxmG0OhwNVVVXw+XwAAJ/Ph4KCAqxbt84c4/F4YLVacfbs2UkfNxKJIBQKJWxElPnmNYD8fj8AwOl0JrQ7nU6zz+/3o6SkJKE/OzsbhYWF5phva2pqgsPhMLfy8vL5LJuIFMmIq2CNjY0wDMPcBgcHVZdERPNgXgPI5XIBAAKBQEJ7IBAw+1wuF4LBYEL/2NgYbt68aY75NrvdDk3TEjYiynzzGkDLly+Hy+VCe3u72RYKhXD27Fnoug4A0HUdQ0ND6O7uNsd0dHQgHo+jqqpqPsshogUuO9k7DA8P47PPPjNvf/HFF+jp6UFhYSHcbjf27NmD1157DStWrMDy5cvx8ssvo7S0FNu2bQMAPPTQQ/jpT3+KnTt34vDhw4jFYqivr8cvfvELlJaWztuOEVEGSPby2ocffigA7trq6upE5JtL8S+//LI4nU6x2+3y2GOPSX9/f8Jj3LhxQ375y19Kfn6+aJomzzzzjITD4Xm/xEdEasz0PWoREVGYf7MSCoXgcDhgGAbXg4gWoJm+RzPiKhgR3Z8Y
QESkDAOIiJRhABGRMgwgIlKGAUREyjCAiEgZBhARKcMAIiJlGEBEpAwDiIiUYQARkTIMICJShgFERMowgIhIGQYQESnDACIiZRhARKQMA4iIlGEAEZEySf8sD9FMiQhGvvoSEeP2D1Ha8gux5MEVCquihYQBRCn11YXTuN7/v+btpf9WyQAiE0/BKKVE4qpLoAWMAUQpJAADiKbBAKLUEUDiDCCaGgOIUkh4CkbTYgBRygjAUzCaFgOIUkcEIqK6ClrAkgqgpqYmrF+/HkuWLEFJSQm2bduG/v7+hDGjo6Pwer0oKipCfn4+amtrEQgEEsYMDAygpqYGeXl5KCkpwd69ezE2Njb3vaEFh2tANJ2kAqizsxNerxdnzpxBW1sbYrEYNm/ejJGREXPMCy+8gPfeew8tLS3o7OzE1atX8cQTT5j94+PjqKmpQTQaxccff4y3334bR48exb59++Zvr2iB4FUwugeZg2AwKACks7NTRESGhoYkJydHWlpazDEXL14UAOLz+URE5MSJE2K1WsXv95tjmpubRdM0iUQiM3pewzAEgBiGMZfyKcXGY1G51HpQug7vNLfL/3NYdVmUBjN9j85pDcgwDABAYWEhAKC7uxuxWAwej8ccs3LlSrjdbvh8PgCAz+fD6tWr4XQ6zTHV1dUIhULo6+ub9HkikQhCoVDCRplAeApG05p1AMXjcezZswePPPIIVq1aBQDw+/2w2WwoKChIGOt0OuH3+80xd4bPRP9E32SamprgcDjMrby8fLZlUxqJCCSeuLZnsfC6B90261eD1+vF+fPncezYsfmsZ1KNjY0wDMPcBgcHU/6cNHfx8RgioesJbbmFpYqqoYVoVh9Gra+vR2trK06fPo2ysjKz3eVyIRqNYmhoKOEoKBAIwOVymWO6uroSHm/iKtnEmG+z2+2w2+2zKZVUEoHIeEKTNcumqBhaiJI6AhIR1NfX4/jx4+jo6MDy5csT+isrK5GTk4P29nazrb+/HwMDA9B1HQCg6zp6e3sRDN7+ioa2tjZomoaKioq57AtlAitPwei2pI6AvF4v3nnnHfz973/HkiVLzDUbh8OB3NxcOBwO7NixAw0NDSgsLISmadi9ezd0XcfGjRsBAJs3b0ZFRQWeeuopHDhwAH6/Hy+99BK8Xi+Pcr4DLAwgukNSAdTc3AwA+PGPf5zQfuTIEfzqV78CABw8eBBWqxW1tbWIRCKorq7Gm2++aY7NyspCa2srdu3aBV3XsXjxYtTV1eHVV1+d255QRrBYslSXQAuIRSTz/lY+FArB4XDAMAxomqa6HJpC7FYY51v+E2O3wmbbsk1P44GV/66wKkqHmb5HeTxMaWWx8giIbmMAUVrx74DoTnw1UFrxCIjuxACi9OIREN2BrwZKKx4B0Z0YQJRW/DsguhNfDZRC8q/vZb3NYrGoKYUWJAYQpYzEBXclENEdGECUOvw2RLoHBhCljEicX0pP02IAUcrwN8HoXhhAlDrxOLgGRNNhAFHKiMQBnoLRNBhAlDI8BaN7YQBR6sS5CE3TYwBRyvAIiO6FAUQp800A8QiIpsYAopSR+Pgki9D8KAbdxgCilImEriM+FjVvZ9nykLO4QF1BtOAwgChlJJ74m2AWqxXWrFn9FB3dpxhAlEYWfiEZJeCrgdLHYuHXcVACBhCljcXCIyBKxFcDpRGPgCgRVwRp1sbHxxEOh6fsv3Xr64TbIoJQaBjZsclDKCcnB4sXL57XGmlhYwDRrF2+fBlbtmxBLBabtP/RtWWo37rGvH3N78f2TZsQ+jo66fitW7eaP/9N3w0MIJq1sbExXL16FdHo5IHyT3c+AtHlGBh9CHlZIeRHT+HK1asIjUQmH//Pf6ayXFqAkloDam5uxpo1a6BpGjRNg67reP/9983+0dFReL1eFBUVIT8/H7W1tQgEAgmPMTAwgJqaGuTl5aGkpAR79+7F2NjY/OwNLSg3YqXoHd6EG7EyDI4+hP8L/wfG4lwD
otuSCqCysjK8/vrr6O7uxrlz5/Doo4/iZz/7Gfr6+gAAL7zwAt577z20tLSgs7MTV69exRNPPGHef3x8HDU1NYhGo/j444/x9ttv4+jRo9i3b9/87hUtCKPxfIyJ7V+3LBgec2Ccn0+lO8kcLV26VP785z/L0NCQ5OTkSEtLi9l38eJFASA+n09ERE6cOCFWq1X8fr85prm5WTRNk0gkMuPnNAxDAIhhGHMtn+agt7dXbDbbxE9f3LU9umG9/HfTSdn/Wpf812s++e2LvxW7LWfK8U8++aTqXaJ5MtP36KzXgMbHx9HS0oKRkRHouo7u7m7EYjF4PB5zzMqVK+F2u+Hz+bBx40b4fD6sXr0aTqfTHFNdXY1du3ahr68PP/zhD5Oq4dKlS8jPz5/tLtAcff7559N+38/Alc9xpuM1BKNu5FqHYY99Nu3ptmEYuHDhQipKpTQbHh6e0bikA6i3txe6rmN0dBT5+fk4fvw4Kioq0NPTA5vNhoKCgoTxTqcTfr8fAOD3+xPCZ6J/om8qkUgEkcjthctQKATgmxcs14/UCYfD0wbQZ1du4rMrH8348WKxGIaGhuahMlJtZGRkRuOSDqDvf//76OnpgWEY+Nvf/oa6ujp0dnYmXWAympqasH///rvaq6qqoGlaSp+bpqZpGqzz+FPLxcXFePjhh+ft8UidiYOEe0n61WOz2fC9730PlZWVaGpqwtq1a/GHP/wBLpcL0Wj0rn/BAoEAXC4XAMDlct11VWzi9sSYyTQ2NsIwDHMbHBxMtmwiWoDm/M9XPB5HJBJBZWUlcnJy0N7ebvb19/djYGAAuq4DAHRdR29vL4LBoDmmra0NmqahoqJiyuew2+3mpf+JjYgyX1KnYI2NjdiyZQvcbjfC4TDeeecdnDp1Ch988AEcDgd27NiBhoYGFBYWQtM07N69G7quY+PGjQCAzZs3o6KiAk899RQOHDgAv9+Pl156CV6vF3a7PSU7SEQLV1IBFAwG8fTTT+PatWtwOBxYs2YNPvjgA/zkJz8BABw8eBBWqxW1tbWIRCKorq7Gm2++ad4/KysLra2t2LVrF3Rdx+LFi1FXV4dXX311fveK0sJqtULTtCn/EjpZubm58/I4lDksMt1ljAUqFArB4XDAMAyejikUi8WmvXqZrLy8PBQVFc3b45E6M32P8rNgNGs5OTkoLy9XXQZlMH4fEBEpwwAiImUYQESkDAOIiJRhABGRMgwgIlKGAUREyjCAiEgZBhARKcMAIiJlGEBEpAwDiIiUYQARkTIMICJShgFERMowgIhIGQYQESnDACIiZRhARKQMA4iIlGEAEZEyDCAiUoYBRETKMICISBkGEBEpwwAiImUYQESkDAOIiJRhABGRMgwgIlImW3UBsyEiAIBQKKS4EiKazMR7c+K9OpWMDKAbN24AAMrLyxVXQkTTCYfDcDgcU/ZnZAAVFhYCAAYGBqbdOUoUCoVQXl6OwcFBaJqmupyMwDmbHRFBOBxGaWnptOMyMoCs1m+WrhwOB18Us6BpGuctSZyz5M3k4ICL0ESkDAOIiJTJyACy2+145ZVXYLfbVZeSUThvyeOcpZZF7nWdjIgoRTLyCIiI7g8MICJShgFERMowgIhImYwMoEOHDmHZsmVYtGgRqqqq0NXVpbokZZqamrB+/XosWbIEJSUl2LZtG/r7+xPGjI6Owuv1oqioCPn5+aitrUUgEEgYMzAwgJqaGuTl5aGkpAR79+7F2NhYOndFmddffx0WiwV79uwx2zhnaSIZ5tixY2Kz2eStt96Svr4+2blzpxQUFEggEFBdmhLV1dVy5MgROX/+vPT09Mjjjz8ubrdbhoeHzTHPPfeclJeXS3t7u5w7d042btwoDz/8sNk/NjYmq1atEo/HI//4xz/kxIkTUlxcLI2NjSp2Ka26urpk2bJlsmbNGnn++efNds5ZemRcAG3YsEG8Xq95e3x8XEpLS6WpqUlhVQtHMBgUANLZ2SkiIkND
Q5KTkyMtLS3mmIsXLwoA8fl8IiJy4sQJsVqt4vf7zTHNzc2iaZpEIpH07kAahcNhWbFihbS1tcmmTZvMAOKcpU9GnYJFo1F0d3fD4/GYbVarFR6PBz6fT2FlC4dhGABuf2C3u7sbsVgsYc5WrlwJt9ttzpnP58Pq1avhdDrNMdXV1QiFQujr60tj9enl9XpRU1OTMDcA5yydMurDqNevX8f4+HjC/3QAcDqduHTpkqKqFo54PI49e/bgkUcewapVqwAAfr8fNpsNBQUFCWOdTif8fr85ZrI5nei7Hx07dgyffvopPvnkk7v6OGfpk1EBRNPzer04f/48PvroI9WlLGiDg4N4/vnn0dbWhkWLFqku5zsto07BiouLkZWVddfViEAgAJfLpaiqhaG+vh6tra348MMPUVZWZra7XC5Eo1EMDQ0ljL9zzlwu16RzOtF3v+nu7kYwGMSPfvQjZGdnIzs7G52dnXjjjTeQnZ0Np9PJOUuTjAogm82GyspKtLe3m23xeBzt7e3QdV1hZeqICOrr63H8+HF0dHRg+fLlCf2VlZXIyclJmLP+/n4MDAyYc6brOnp7exEMBs0xbW1t0DQNFRUV6dmRNHrsscfQ29uLnp4ec1u3bh22b99u/jfnLE1Ur4In69ixY2K32+Xo0aNy4cIFefbZZ6WgoCDhasR3ya5du8ThcMipU6fk2rVr5vb111+bY5577jlxu93S0dEh586dE13XRdd1s3/ikvLmzZulp6dHTp48KQ888MB36pLynVfBRDhn6ZJxASQi8sc//lHcbrfYbDbZsGGDnDlzRnVJygCYdDty5Ig55tatW/LrX/9ali5dKnl5efLzn/9crl27lvA4X375pWzZskVyc3OluLhYfvOb30gsFkvz3qjz7QDinKUHv46DiJTJqDUgIrq/MICISBkGEBEpwwAiImUYQESkDAOIiJRhABGRMgwgIlKGAUREyjCAiEgZBhARKcMAIiJl/h9NRLzTO+PXiAAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 300x300 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "200.0"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "play(True)[-1]"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "第9章-策略梯度算法.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python [conda env:pt39]",
   "language": "python",
   "name": "conda-env-pt39-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
